module SvmToolkit

  # Extends the Java SVM class
  #
  # Available methods include:
  # 
  # Svm.svm_train(problem, param)
  # 
  # problem:: instance of Problem
  # param:: instance of Parameter
  #
  # Returns an instance of Model
  #
  # Svm.svm_cross_validation(problem, param, nr_folds, target)
  #
  # problem:: instance of Problem
  # param:: instance of Parameter
  # nr_folds:: number of folds
  # target:: resulting predictions in an array
  #
  class Svm

    # Performs a grid search over the given cost/gamma values using an RBF
    # kernel, returning the best performing model and optionally displaying
    # a contour map of performance.
    #
    # One model is trained on +training_set+ for every gamma/cost pair and
    # evaluated against +cross_valn_set+. The work is run on a ForkJoinPool
    # created for this call; the pool is shut down again before returning, 
    # even when the search raises an error, so no worker threads are left 
    # behind.
    #
    # training_set::   instance of Problem, used for training
    # cross_valn_set:: instance of Problem, used for evaluating models
    # costs::          array of cost values to search across
    #                  (defaults to 2**-2 up to 2**3)
    # gammas::         array of gamma values to search across
    #                  (defaults to 2**-2 up to 2**3)
    # params::         Optional parameters include:
    # * :evaluator => Evaluator::OverallAccuracy, the name of the class 
    #   to use for computing performance
    # * :show_plot => false, whether to display contour plot
    #
    # Returns an instance of Model, the best performing model.
    #
    def Svm.cross_validation_search(training_set, cross_valn_set, 
                                     costs = [-2,-1,0,1,2,3].collect {|i| 2**i}, 
                                     gammas = [-2,-1,0,1,2,3].collect {|i| 2**i}, 
                                     params = {})
      evaluator = params.fetch :evaluator, Evaluator::OverallAccuracy
      show_plot = params.fetch :show_plot, false

      fjp = ForkJoinPool.new
      begin
        task = CrossValidationSearch.new gammas, costs, training_set, cross_valn_set, evaluator
        results, best_model = fjp.invoke task
      ensure
        # the pool is private to this call, so release its worker threads
        fjp.shutdown
      end

      if show_plot
        # the plot axes are on a log2 scale: costs along x, gammas along y
        ContourDisplay.new(costs.collect {|n| Math.log2(n)}, 
                           gammas.collect {|n| Math.log2(n)}, 
                           results)
      end

      best_model
    end

    # NOTE(review): +private+ only changes the visibility of methods defined
    # after it. The nested classes below are constants, so they remain 
    # reachable from outside (e.g. Svm::ContourDisplay) despite this marker.
    private
    # Fork/join task which trains and evaluates one model for every
    # gamma/cost combination, collecting the scores and the best model.
    class CrossValidationSearch < RecursiveTask
      # Creates an instance of the CrossValidationSearch.
      #
      # gammas:: array of gamma values to search over
      # costs:: array of cost values to search over
      # training_set:: for building the model
      # cross_valn_set:: for testing the model
      # evaluator:: name of Evaluator class, used for evaluating the model
      #
      def initialize gammas, costs, training_set, cross_valn_set, evaluator
        super()

        @training_set = training_set
        @cross_valn_set = cross_valn_set
        @evaluator = evaluator
        @gammas = gammas
        @costs = costs
      end

      # Runs every gamma/cost combination as its own forked SvmTrainer.
      #
      # Returns a pair: the table of evaluation values (one row per gamma, 
      # one column per cost) and the best performing model. The model is 
      # nil when there were no combinations to try.
      #
      def compute
        trainers = build_trainers

        # start every trainer before waiting on any of them
        trainers.each { |trainer| trainer.fork }

        best_model = nil
        best_result = nil
        grid = []
        position = 0 # trainers were created in the same gamma-then-cost order

        @gammas.each do |gamma|
          row = []
          @costs.each do |cost|
            model, result = trainers[position].join
            position += 1

            best_model, best_result = model, result if result.better_than?(best_result)

            puts "Result for cost = #{cost}  gamma = #{gamma} is #{result.value}"
            row << result.value
          end
          grid << row
        end

        [grid, best_model]
      end

      private

      # Creates one (unstarted) SvmTrainer per gamma/cost pair, 
      # ordered by gamma and then by cost.
      def build_trainers
        trainers = []
        @gammas.each do |gamma|
          @costs.each do |cost|
            trainers << SvmTrainer.new(@training_set, rbf_parameters(cost, gamma), 
                                       @cross_valn_set, @evaluator)
          end
        end
        trainers
      end

      # Creates the C-SVC/RBF Parameter instance for the given cost and gamma.
      def rbf_parameters cost, gamma
        Parameter.new(
          :svm_type => Parameter::C_SVC,
          :kernel_type => Parameter::RBF,
          :cost => cost,
          :gamma => gamma
        )
      end
    end

    # Fork/join task which trains one SVM model with a fixed set of 
    # parameters and evaluates it against a cross-validation set.
    class SvmTrainer < RecursiveTask

      # Creates an instance of an SvmTrainer.
      #
      # training_set:: used to train the model
      # parameters:: parameters for building the model
      # cross_valn_set:: used to test the model performance
      # evaluator:: class name of Evaluator to use for evaluating the model performance
      #
      def initialize training_set, parameters, cross_valn_set, evaluator
        super()

        @training_set, @parameters = training_set, parameters
        @cross_valn_set, @evaluator = cross_valn_set, evaluator
      end

      # Trains a model on the training set, then evaluates it on the 
      # cross-validation set using the given evaluator.
      #
      # Returns a pair: the trained model and its evaluation.
      #
      def compute
        trained = Svm.svm_train(@training_set, @parameters)
        evaluation = trained.evaluate_dataset(@cross_valn_set, :evaluator => @evaluator)

        [trained, evaluation]
      end
    end

    # Swing frame which shows the cross-validation performance as a 
    # contour plot, with a marker on every evaluated grid point.
    #
    class ContourDisplay < javax.swing.JFrame
      # Creates an instance of the ContourDisplay and makes it visible.
      #
      # xs:: array of x-coordinates
      # ys:: array of y-coordinates
      # zs:: array of rows of values; zs[i][j] is the value at (xs[j], ys[i])
      #
      def initialize(xs, ys, zs)
        super("Cross-Validation Performance")
        self.setSize(500, 400)

        row_count = ys.size
        column_count = xs.size

        # Java double[][] grids, one row per y value and one column per x value
        cxs = Java::double[][row_count].new
        cys = Java::double[][row_count].new
        czs = Java::double[][row_count].new

        row_count.times do |row|
          cxs[row] = Java::double[column_count].new
          cys[row] = Java::double[column_count].new
          czs[row] = Java::double[column_count].new

          column_count.times do |column|
            cxs[row][column] = xs[column]
            cys[row][column] = ys[row]
            czs[row][column] = zs[row][column]
          end
        end

        plot = build_plot(cxs, cys, czs)

        # mark each grid point which was evaluated
        marker = DiamondSymbol.new
        marker.border_color = java.awt::Color.blue
        marker.fill_color = java.awt::Color.blue
        marker.size = 4

        points = PlotRun.new
        row_count.times do |row|
          column_count.times do |column|
            points.add(PlotDatum.new(cxs[row][column], cys[row][column], false, marker))
          end
        end
        plot.runs << points

        panel = PlotPanel.new(plot)
        panel.background = java.awt::Color.white
        add(panel)

        self.setDefaultCloseOperation(javax.swing.WindowConstants::DISPOSE_ON_CLOSE)
        self.visible = true
      end

      private

      # Creates the ContourPlot for the given coordinate/value grids, 
      # with its contours coloured from green to red.
      def build_plot(cxs, cys, czs)
        # NOTE(review): 10 and false are presumably the number of contour
        # levels and a log-spacing flag - confirm against ContourPlot.
        plot = ContourPlot.new(cxs, cys, czs, 
                               10, false, 
                               "", "Cost (log-scale)", "Gamma (log-scale)", 
                               nil, nil)
        plot.colorizeContours(java.awt::Color.green, java.awt::Color.red)
        plot
      end
    end
  end
end
